{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import torch.backends.cudnn as cudnn\n",
    "from munch import Munch\n",
    "from tqdm import tqdm_notebook as tqdm\n",
    "\n",
    "import datasets\n",
    "import models\n",
    "import transforms\n",
    "import utils"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Location of the YAML configuration file for the pretrained run evaluated in this notebook.\n",
    "config_path = 'pretrained/drn_d_22_OilChange/config.yml'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Parse the YAML config into a Munch so settings are reachable as attributes (e.g. cfg.data.classes).\n",
    "with open(config_path) as config_file:\n",
    "    cfg = Munch.fromYAML(config_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "=> loaded checkpoint 'pretrained/drn_d_22_OilChange/checkpoint_00000900.pth.tar' (epoch 900)\n"
     ]
    }
   ],
   "source": [
    "# Build the segmentation network and wrap it in DataParallel on the GPU.\n",
    "segmenter = models.DRNSeg(cfg.arch, cfg.data.classes, None, pretrained=True)\n",
    "model = torch.nn.DataParallel(segmenter).cuda()\n",
    "cudnn.benchmark = True\n",
    "\n",
    "# Restore the weights from the checkpoint named in the config. The state dict is loaded\n",
    "# into the DataParallel wrapper, so it was presumably saved from a wrapped model -- confirm.\n",
    "checkpoint_path = cfg.training.resume\n",
    "checkpoint = torch.load(checkpoint_path)\n",
    "model.load_state_dict(checkpoint['state_dict'])\n",
    "\n",
    "# Switch to inference mode before evaluating.\n",
    "model.eval()\n",
    "print(\"=> loaded checkpoint '{}' (epoch {})\".format(checkpoint_path, checkpoint['epoch']))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Validation-time preprocessing: only the project's ToTensor transform is applied.\n",
    "val_transforms = transforms.Compose([transforms.ToTensor()])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "loading annotations into memory...\n",
      "Done (t=0.09s)\n",
      "creating index...\n",
      "index created!\n"
     ]
    }
   ],
   "source": [
    "# Validation split: the annotation file name is the configured one prefixed with 'val_'.\n",
    "val_ann_file = 'val_' + cfg.data.ann_file\n",
    "val_dataset = datasets.Dataset(cfg.data.root, val_ann_file, 'val', val_transforms)\n",
    "\n",
    "# One sample per batch, in dataset order, loaded in the main process.\n",
    "val_loader = torch.utils.data.DataLoader(\n",
    "    val_dataset,\n",
    "    batch_size=1,\n",
    "    shuffle=False,\n",
    "    num_workers=0,\n",
    "    pin_memory=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "88215f6a4172416184c3b8d81f8b5515",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=1950), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "accuracy: 99.82\n",
      "mean IoU: 0.9650\n"
     ]
    }
   ],
   "source": [
    "# Evaluate the restored model on the validation loader: a running accuracy meter plus a\n",
    "# (classes x classes) histogram from which per-class IoU is derived.\n",
    "acc = utils.AverageMeter()\n",
    "hist = np.zeros((cfg.data.classes, cfg.data.classes))\n",
    "with torch.no_grad():\n",
    "    # Named `images` rather than `input` so the Python builtin is not shadowed.\n",
    "    for images, target in tqdm(val_loader):\n",
    "        target = target.cuda(non_blocking=True)\n",
    "        output = model(images)\n",
    "        acc.update(utils.accuracy(output, target), images.size(0))\n",
    "        # Index of the maximum along dim 1 (assumed to be the class dimension -- confirm).\n",
    "        _, pred = output.max(1)\n",
    "        # No `.data` needed: gradients are already disabled by torch.no_grad().\n",
    "        hist += utils.fast_hist(pred.cpu().numpy().flatten(), target.cpu().numpy().flatten(), cfg.data.classes)\n",
    "ious = utils.per_class_iou(hist)\n",
    "\n",
    "# Report the average over the whole validation set. `acc.val` would only be the most recent\n",
    "# batch (assumes utils.AverageMeter follows the usual `.val` / `.avg` convention -- confirm).\n",
    "print('accuracy: {:.2f}'.format(acc.avg))\n",
    "print('mean IoU: {:.4f}'.format(ious.mean()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
